{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[WARNING] ME(770836:140409999501120,MainProcess):2023-06-26-21:55:57.544.845 [mindspore/common/_decorator.py:40] 'Fills' is deprecated from version 2.0.fill' instead.\n",
      "/home/daiyuxin/had/MindNLP/mindnlp/mindnlp/utils/download.py:29: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from tqdm.autonotebook import tqdm\n"
     ]
    }
   ],
   "source": [
    "import mindspore\n",
    "from mindspore.dataset import GeneratorDataset, transforms\n",
    "from mindnlp.transforms import NezhaTokenizer, PadTransform"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2023-06-26 20:04:12--  https://raw.githubusercontent.com/JackHCC/Chinese-Text-Classification-PyTorch/master/THUCNews/data/test.txt\n",
      "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.108.133, 185.199.109.133, ...\n",
      "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 551596 (539K) [text/plain]\n",
      "Saving to: ‘test.txt’\n",
      "\n",
      "test.txt            100%[===================>] 538.67K  43.2KB/s    in 10s     \n",
      "\n",
      "2023-06-26 20:04:24 (53.5 KB/s) - ‘test.txt’ saved [551596/551596]\n",
      "\n",
      "--2023-06-26 20:04:24--  https://raw.githubusercontent.com/JackHCC/Chinese-Text-Classification-PyTorch/master/THUCNews/data/dev.txt\n",
      "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.108.133, 185.199.109.133, ...\n",
      "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 551313 (538K) [text/plain]\n",
      "Saving to: ‘dev.txt’\n",
      "\n",
      "dev.txt             100%[===================>] 538.39K   576KB/s    in 0.9s    \n",
      "\n",
      "2023-06-26 20:04:27 (576 KB/s) - ‘dev.txt’ saved [551313/551313]\n",
      "\n",
      "--2023-06-26 20:04:27--  https://raw.githubusercontent.com/JackHCC/Chinese-Text-Classification-PyTorch/master/THUCNews/data/train.txt\n",
      "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.108.133, 185.199.109.133, ...\n",
      "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 9946122 (9.5M) [text/plain]\n",
      "Saving to: ‘train.txt’\n",
      "\n",
      "train.txt           100%[===================>]   9.49M  19.0KB/s    in 9m 13s  \n",
      "\n",
      "2023-06-26 20:13:41 (17.6 KB/s) - ‘train.txt’ saved [9946122/9946122]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Fetch the three THUCNews splits into the working directory.\n",
    "# -nc (no-clobber) skips a file that already exists, so re-running this cell\n",
    "# neither downloads it again nor leaves numbered duplicates such as test.txt.1.\n",
    "!wget -nc https://raw.githubusercontent.com/JackHCC/Chinese-Text-Classification-PyTorch/master/THUCNews/data/test.txt\n",
    "!wget -nc https://raw.githubusercontent.com/JackHCC/Chinese-Text-Classification-PyTorch/master/THUCNews/data/dev.txt\n",
    "!wget -nc https://raw.githubusercontent.com/JackHCC/Chinese-Text-Classification-PyTorch/master/THUCNews/data/train.txt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# prepare dataset\n",
    "class SentimentDataset:\n",
    "    \"\"\"Indexable source of (label, text) pairs read from a tab-separated file.\n",
    "\n",
    "    Every non-blank line is expected to hold the text, one tab, and an\n",
    "    integer label, in that order.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, path):\n",
    "        \"\"\"Remember ``path`` and load all samples from it immediately.\"\"\"\n",
    "        self.path = path\n",
    "        self._labels, self._text_a = [], []\n",
    "        self._load()\n",
    "\n",
    "    def _load(self):\n",
    "        \"\"\"Fill the parallel label/text lists from the file at ``self.path``.\"\"\"\n",
    "        with open(self.path, \"r\", encoding=\"utf-8\") as f:\n",
    "            for line in f:\n",
    "                line = line.rstrip(\"\\n\")\n",
    "                # Blank lines carry no sample; skipping them (rather than\n",
    "                # dropping the last split element) also keeps the final\n",
    "                # record of a file that lacks a trailing newline.\n",
    "                if not line.strip():\n",
    "                    continue\n",
    "                # The label is the last field, so split from the right: a tab\n",
    "                # inside the text must not break the two-way unpacking.\n",
    "                text_a, label = line.rsplit(\"\\t\", 1)\n",
    "                self._labels.append(int(label))\n",
    "                self._text_a.append(text_a)\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        \"\"\"Return the ``(label, text)`` pair stored at ``index``.\"\"\"\n",
    "        return self._labels[index], self._text_a[index]\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Return the number of loaded samples.\"\"\"\n",
    "        return len(self._labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_dataset(source, tokenizer, pad_value, max_seq_len=64, batch_size=32, shuffle=True):\n",
    "    \"\"\"Turn an indexable (label, text) source into a batched MindSpore dataset.\n",
    "\n",
    "    The text column is run through ``tokenizer`` and then ``PadTransform``\n",
    "    (configured with ``max_seq_len`` and ``pad_value``), the label column is\n",
    "    cast to int32, the text column is renamed to ``input_ids`` and the result\n",
    "    is grouped into batches of ``batch_size``.\n",
    "    \"\"\"\n",
    "    source_columns = [\"label\", \"text_a\"]\n",
    "    output_columns = [\"label\", \"input_ids\"]\n",
    "\n",
    "    # Per-column operation lists applied by the two map() calls below.\n",
    "    text_ops = [tokenizer, PadTransform(max_seq_len, pad_value=pad_value)]\n",
    "    label_ops = [transforms.TypeCast(mindspore.int32)]\n",
    "\n",
    "    pipeline = GeneratorDataset(source, column_names=source_columns, shuffle=shuffle)\n",
    "    pipeline = pipeline.map(operations=text_ops, input_columns=\"text_a\")\n",
    "    pipeline = pipeline.map(operations=label_ops, input_columns=\"label\")\n",
    "    pipeline = pipeline.rename(input_columns=source_columns, output_columns=output_columns)\n",
    "\n",
    "    return pipeline.batch(batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the 'nezha-cn-base' tokenizer and look up the id of its '[PAD]' token,\n",
    "# which is later handed to process_dataset as the padding value.\n",
    "tokenizer = NezhaTokenizer.from_pretrained('nezha-cn-base')\n",
    "pad_value = tokenizer.token_to_id('[PAD]')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the batched pipelines for the three downloaded splits using the default\n",
    "# max_seq_len and batch_size; only the test split is created with shuffle=False.\n",
    "dataset_train = process_dataset(SentimentDataset(\"train.txt\"), tokenizer, pad_value)\n",
    "dataset_val = process_dataset(SentimentDataset(\"dev.txt\"), tokenizer, pad_value)\n",
    "dataset_test = process_dataset(SentimentDataset(\"test.txt\"), tokenizer, pad_value, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "import mindspore as ms\n",
    "from mindnlp.models.nezha import NezhaConfig, NezhaForSequenceClassification\n",
    "\n",
    "# Build the classifier from the local Nezha config, with a 10-label head.\n",
    "with open(\"../ckpt_ms/nezha-cn-base/config.json\") as config_file:\n",
    "    config = NezhaConfig(**json.load(config_file))\n",
    "config.num_labels = 10\n",
    "model = NezhaForSequenceClassification(config)\n",
    "\n",
    "# NOTE(review): this file name matches what the training loop further down\n",
    "# writes for epoch index 2, so the file has to exist before this cell runs.\n",
    "params_dict = ms.load_checkpoint(\"nezha_classfication_epoch_2.ckpt\")\n",
    "# Return value of the load is kept for inspection; nothing below reads it.\n",
    "params_not_load = ms.load_param_into_net(model, params_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[WARNING] ME(2758247:140715544528704,MainProcess):2023-06-04-19:05:20.679.142 [mindspore/ops/operations/math_ops.py:4331] The 'NPUAllocFloatStatus' operator will be deprecated in the future. Please don't use it.\n",
      "[WARNING] ME(2758247:140715544528704,MainProcess):2023-06-04-19:05:20.679.707 [mindspore/ops/operations/math_ops.py:4465] The 'NPUClearFloatStatus' operator will be deprecated in the future. Please don't use it.\n",
      "[WARNING] ME(2758247:140715544528704,MainProcess):2023-06-04-19:05:20.680.112 [mindspore/ops/operations/math_ops.py:4401] The 'NPUGetFloatStatus' operator will be deprecated in the future. Please don't use it.\n",
      "[WARNING] ME(2758247:140715544528704,MainProcess):2023-06-04-19:05:20.680.592 [mindspore/common/api.py:843] 'mindspore.ms_class' will be deprecated and removed in a future version. Please use 'mindspore.jit_class' instead.\n"
     ]
    }
   ],
   "source": [
    "from mindnlp._legacy.amp import auto_mixed_precision\n",
    "from mindspore import nn, ops\n",
    "\n",
    "# Rebind the model to its mixed-precision version (level 'O1').\n",
    "model = auto_mixed_precision(model, 'O1')\n",
    "\n",
    "loss_fn = nn.CrossEntropyLoss()\n",
    "optimizer = nn.Adam(model.trainable_params(), learning_rate=2e-5)\n",
    "\n",
    "\n",
    "def forward(input_ids, label):\n",
    "    \"\"\"Return the cross-entropy loss of the model's first output against ``label``.\"\"\"\n",
    "    return loss_fn(model(input_ids)[0], label)\n",
    "\n",
    "\n",
    "# Returns (loss, gradients); gradients are taken for the trainable parameters.\n",
    "grad_fn = ops.value_and_grad(forward, None, model.trainable_params())\n",
    "\n",
    "\n",
    "def train_step(input_ids, label):\n",
    "    \"\"\"Apply one optimizer update for the batch and return its loss.\"\"\"\n",
    "    loss_value, gradients = grad_fn(input_ids, label)\n",
    "    optimizer(gradients)\n",
    "    return loss_value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 0: 100%|██████████| 5625/5625 [24:31<00:00,  3.82it/s, loss=0.27862355]\n",
      "Epoch 1: 100%|██████████| 5625/5625 [24:42<00:00,  3.80it/s, loss=0.11460448] \n",
      "Epoch 2: 100%|██████████| 5625/5625 [24:44<00:00,  3.79it/s, loss=0.06515992]   \n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "from tqdm import tqdm\n",
    "from mindspore.train.serialization import save_checkpoint\n",
    "# NOTE(review): mindspore is already imported by the time this line runs;\n",
    "# confirm that setting the variable this late still selects the intended device.\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = '2'\n",
    "\n",
    "total = dataset_train.get_dataset_size()\n",
    "for epoch in range(3):\n",
    "    # MindSpore cells stay in inference mode until set_train() is called, which\n",
    "    # leaves mode-dependent layers such as dropout inactive; enable training\n",
    "    # mode for the fine-tuning steps.\n",
    "    model.set_train(True)\n",
    "    with tqdm(total=total) as progress:\n",
    "        progress.set_description(f'Epoch {epoch}')\n",
    "        loss_total = 0\n",
    "        cur_step_nums = 0\n",
    "        for label, data in dataset_train.create_tuple_iterator():\n",
    "            loss = train_step(data, label)\n",
    "            loss_total += loss\n",
    "            cur_step_nums += 1\n",
    "            # Running mean of the loss over the steps of the current epoch.\n",
    "            progress.set_postfix(loss=loss_total/cur_step_nums)\n",
    "            progress.update(1)\n",
    "        # One checkpoint per epoch, named by the epoch index.\n",
    "        save_checkpoint(model, f\"nezha_classfication_epoch_{epoch}.ckpt\")\n",
    "\n",
    "# Leave the network in inference mode so that evaluation and prediction cells\n",
    "# run afterwards are deterministic.\n",
    "model.set_train(False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mindspore import Tensor\n",
    "\n",
    "def predict(text, label=None):\n",
    "    \"\"\"Print the category the global ``model`` predicts for ``text``.\n",
    "\n",
    "    ``label`` is an optional integer class id; when it is given, its category\n",
    "    name is appended to the printed line.\n",
    "    \"\"\"\n",
    "    # Class id -> category name, ids 0..9 in list order.\n",
    "    label_map = dict(enumerate(['财经', '房产', '股票', '教育', '科技',\n",
    "                                '社会', '时政', '运动', '游戏', '娱乐']))\n",
    "    token_ids = tokenizer.encode(text).ids\n",
    "    # Wrap the ids in a list so the model receives a batch of one sequence.\n",
    "    outputs = model(Tensor([token_ids]))\n",
    "    predict_label = outputs[0].asnumpy().argmax()\n",
    "    info = f\"inputs: '{text}', predict: '{label_map[predict_label]}'\"\n",
    "    if label is not None:\n",
    "        info = info + f\" , label: '{label_map[label]}'\"\n",
    "    print(info)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 313/313 [00:41<00:00,  7.63it/s, acc=0.941]\n"
     ]
    }
   ],
   "source": [
    "from mindnlp.metrics import Accuracy\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Evaluate in inference mode regardless of what earlier cells left behind.\n",
    "model.set_train(False)\n",
    "metric = Accuracy()\n",
    "total = dataset_val.get_dataset_size()\n",
    "with tqdm(total=total) as progress:\n",
    "    for label, data in dataset_val.create_tuple_iterator():\n",
    "        pred = model(data)[0]\n",
    "        metric.update(pred, label)\n",
    "        # The metric accumulates over every update() call, so eval() already is\n",
    "        # the accuracy over all samples seen so far. Averaging those running\n",
    "        # values across steps would report a skewed number.\n",
    "        progress.set_postfix(acc=metric.eval())\n",
    "        progress.update(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "inputs: '原神是一款由米哈游开发的开放世界冒险', predict: '游戏'\n"
     ]
    }
   ],
   "source": [
    "# Try the classifier on a single sentence; no reference label is passed, so\n",
    "# only the predicted category is printed.\n",
    "predict(\"原神是一款由米哈游开发的开放世界冒险\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
